
add aten::_scaled_dot_product_flash_attention to perf model #217

Merged
ajassani merged 1 commit into main from feat/fa_variant
Jun 30, 2025

Conversation

@ajassani
Collaborator

No description provided.

@ajassani ajassani requested a review from Copilot June 30, 2025 14:32
Contributor

Copilot AI left a comment

Pull Request Overview

This PR adds support for flash attention in the performance model by introducing the aten::_scaled_dot_product_flash_attention operator.

  • Adds a new mapping to link the operator in torch_op_mapping.py.
  • Implements a new SDPA subclass in perf_model.py to handle flash attention-specific parameters.

Reviewed Changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 2 comments.

File | Description
TraceLens/PerfModel/torch_op_mapping.py | Added mapping for aten::_scaled_dot_product_flash_attention to invoke the new handler.
TraceLens/PerfModel/perf_model.py | Introduced a new class for flash attention with custom parameter parsing in get_param_details.

        except (ValueError, TypeError):
            pass
        is_causal = concrete_inputs[4].lower() == 'true' if concrete_inputs[4] not in ('', 'None') else False
        # scale = float(concrete_inputs[5]) if concrete_inputs[5] not in ('', 'None') else None

Copilot AI Jun 30, 2025

Consider removing the commented scale conversion code or fully implementing it to avoid confusion in future maintenance.

Copilot uses AI. Check for mistakes.
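If the commented-out scale conversion were implemented rather than removed, one possible shape is sketched below. The helper name is hypothetical; the indexing of `concrete_inputs` follows the snippet above.

```python
# Hypothetical helper implementing the commented-out scale conversion;
# index 5 for scale is assumed from the reviewed snippet.
def parse_scale(concrete_inputs):
    """Return scale as a float, or None when absent or unparseable."""
    raw = concrete_inputs[5] if len(concrete_inputs) > 5 else "None"
    if raw in ("", "None"):
        return None
    try:
        return float(raw)
    except (ValueError, TypeError):
        return None
```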
        try:
            dropout_p = float(concrete_inputs[3])
        except (ValueError, TypeError):
            pass

Copilot AI Jun 30, 2025

[nitpick] Consider logging a warning when the conversion of dropout_p fails in order to aid debugging instead of silently swallowing the error.

Suggested change
-            pass
+            logging.warning(f"Failed to convert dropout_p value '{concrete_inputs[3]}' to float.")

Copilot uses AI. Check for mistakes.
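Applied in full, the reviewer's suggestion to warn instead of silently swallowing the error could look like this sketch. The function name and default value are assumptions for illustration; only the conversion-and-warn pattern comes from the suggested change.

```python
import logging

# Sketch of the suggested change: warn on a failed dropout_p conversion
# instead of silently passing. The default of 0.0 is an assumption.
def parse_dropout_p(concrete_inputs, default=0.0):
    try:
        return float(concrete_inputs[3])
    except (ValueError, TypeError, IndexError):
        logging.warning(
            "Failed to convert dropout_p value %r to float; using default %s",
            concrete_inputs[3] if len(concrete_inputs) > 3 else None,
            default,
        )
        return default
```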
@ajassani ajassani merged commit 7f87fb8 into main Jun 30, 2025
@ajassani ajassani deleted the feat/fa_variant branch June 30, 2025 14:35